

from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
from typing import Dict, List

# ---------------------------------------------------------------------------- #
# Persuasion-strategy vocabulary                                               #
# ---------------------------------------------------------------------------- #
# Free-text description of every persuasion strategy. The text is formatted
# like a dict literal but is stored as one plain (multi-line) string and is
# never parsed in the visible code.
# NOTE(review): `topics` is not referenced anywhere else in the visible part of
# this file -- presumably prompt text consumed elsewhere; confirm before
# editing or removing it.
# NOTE(review): the 'Social Proof' description starts with 'efers' (probably
# 'Refers'). It is left untouched here because this is runtime text.
topics = """Persuasion Strategies Vocabulary: { 'Authority':'Authority indicated through expertise, source of power, third-party approval, credentials, and awards','Social Identity':'Normative influence, which involves conformity with the positive expectations of 'another', who could be another person, a group, or ones self.using the idea of 'everyone else is doing it' to influence people's behavior.', 'Social Proof':'efers to the use of testimonials, reviews, or other forms of social validation to demonstrate the popularity, trustworthiness, or quality of a product or brand. By leveraging social proof, advertisements can increase consumers' confidence and trust in the product or brand, and encourage them to make a purchase.','Reciprocity':'By obligating the recipient of an act to repayment in the future, the rule for reciprocation begets a sense of future obligation, often unequal in nature','Foot in the door':'Starting with small requests followed by larger requests to facilitate compliance while maintaining cognitive coherence.','Overcoming Reactance':'Overcoming resistance (reactance) by postponing consequences to the future, by focusing resistance on realistic concerns, by forewarning that a message will be coming, by acknowledging resistance, by raising self-esteem and a sense of efficacy.','Concreteness':'concreteness refers to the use of specific, tangible details or examples to make an abstract or complex concept more concrete and relatable to consumers. By using concrete language and imagery, advertisements can increase consumers' understanding and engagement with the product or brand, and create a more vivid and memorable impression.','Anchoring and Comparison':'anchoring refers to the use of a reference point or starting point to influence consumers' perceptions of value or price. Comparison refers to the use of side-by-side or direct comparisons to demonstrate the superiority of a product or brand over competitors. 
Both anchoring and comparison are common persuasion strategies used in advertising to influence consumer decision-making.','Social Impact':'Refers to the positive effect that an advertisement has on society or the broader community. This can include promoting social causes, raising awareness about important issues, or encouraging positive behaviors and attitudes.','Scarcity':'People assign more value to opportunities when they are less available. This happens due to psychological reactance of losing freedom of choice when things are less available or they use availability as a cognitive shortcut for gauging quality.','Unclear':'If the strategy used in the advertisement is unclear or it is not in English or no strategy is used as the central message of the advertisement'}"""

# Canonical persuasion-strategy labels. Order matters: the substring fallback
# in normalize_label() returns the first entry (in this order) that matches.
PERSUASION_STRATEGIES: List[str] = [
    "Authority",
    "Social Identity",
    "Social Proof",
    "Reciprocity",
    "Foot in the door",
    "Overcoming Reactance",
    "Concreteness",
    "Anchoring and Comparison",
    "Social Impact",
    "Scarcity",
    "Unclear",
]


# ---------------------------------------------------------------------------- #
# Normalisation helpers                                                       #
# ---------------------------------------------------------------------------- #

import re

_SPACE_RE = re.compile(r"\s+")

# Characters trimmed from both ends of a label before it is matched.
_STRIP_CHARS = "'\". ,;:!-_"


def _clean_label(text: str) -> str:
    """Lower-case *text*, trim surrounding whitespace/punctuation, collapse inner whitespace."""
    return _SPACE_RE.sub(" ", text.lower().strip().strip(_STRIP_CHARS))


# Cleaned strategy name -> canonical key. Built with the same helper that
# normalize_label() uses so both sides of the lookup always agree.
NORMALIZED_STRATEGY_MAP: Dict[str, str] = {
    _clean_label(key): key for key in PERSUASION_STRATEGIES
}


def normalize_label(label: str) -> str:
    """Map a raw label to its canonical persuasion strategy when possible.

    The label is lower-cased, trimmed of surrounding whitespace/punctuation and
    has its inner whitespace collapsed. An exact hit in
    ``NORMALIZED_STRATEGY_MAP`` wins; otherwise the first strategy (in
    ``PERSUASION_STRATEGIES`` order) whose cleaned name contains, or is
    contained in, the cleaned label is returned.

    Args:
        label: Raw label text, e.g. a model prediction or an annotation.

    Returns:
        The canonical strategy key on a match, the cleaned label when nothing
        matches, and ``""`` for ``None``, non-string values, or labels that are
        empty once cleaned.
    """
    if not isinstance(label, str):
        # Covers None as well as malformed JSON values (lists, numbers, ...).
        return ""

    lab = _clean_label(label)
    if not lab:
        # "" is a substring of every strategy name, so without this guard the
        # fallback below would turn a blank label into the first strategy.
        return ""

    canon = NORMALIZED_STRATEGY_MAP.get(lab)
    if canon is not None:
        return canon

    # Substring fallback for near-misses such as "the scarcity principle".
    for norm, canon in NORMALIZED_STRATEGY_MAP.items():
        if lab in norm or norm in lab:
            return canon
    return lab


# ---------------------------------------------------------------------------- #
# Utility functions                                                             #
# ---------------------------------------------------------------------------- #


def load_ground_truth(path: Path) -> Dict[str, set[str]]:
    """Read the annotation JSON at *path* into *video_id* -> set of normalised strategies.

    Each value in the file may be a single string or a list; list items that
    are not strings are ignored. Every kept label goes through
    ``normalize_label``.

    Raises:
        ValueError: If a value is neither a string nor a list, or if a video
            ends up with no usable strategy.
    """
    with path.open("r", encoding="utf-8") as handle:
        annotations: Dict[str, List[str]] = json.load(handle)

    result: Dict[str, set[str]] = {}
    for video_id, value in annotations.items():
        if not isinstance(value, (str, list)):
            raise ValueError(f"Unexpected ground-truth value for video {video_id}: {type(value).__name__}")

        # Treat a bare string as a one-element list so both shapes share a path.
        candidates = [value] if isinstance(value, str) else value
        labels = {normalize_label(item) for item in candidates if isinstance(item, str)}
        if not labels:
            raise ValueError(f"No valid strategies provided for video {video_id} in ground truth.")

        result[video_id] = labels

    return result


def load_predictions(pred_dir: Path) -> List[tuple[str, str]]:
    """Load predictions from **all** *.json* files inside *pred_dir*.

    Files are visited in sorted order so the result is deterministic. Each file
    may hold a single record object or a list of them. A record's id is taken
    from ``"video_id"`` (falling back to the file stem) and its strategy from
    ``"final_topic"`` or, failing that, ``"predicted_topic"``.

    Files that cannot be read, decoded or parsed, and records that are
    malformed, are reported with a ``[WARN]`` line and skipped.

    Returns:
        A list of *(video_id, strategy)* pairs **including duplicates** so that
        accuracy can be computed both with and without deduplication.
    """
    records_all: List[tuple[str, str]] = []

    # glob() gives no ordering guarantee; sort for reproducible output.
    for json_path in sorted(pred_dir.glob("*.json")):
        try:
            with json_path.open("r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError, so
            # one unreadable file no longer aborts the whole load.
            print(f"[WARN] Failed to parse {json_path.name}: {e}")
            continue

        # Some prediction files are a single object, others are a list of objects.
        recs_in_file = data if isinstance(data, list) else [data]

        for rec in recs_in_file:
            if not isinstance(rec, dict):
                print(f"[WARN] Unexpected record type in {json_path.name}: {type(rec).__name__}; skipping.")
                continue

            # JSON object keys are always strings, so a ground truth loaded from
            # JSON can only be matched by string ids; coerce numeric ids.
            video_id = str(rec.get("video_id") or json_path.stem)

            # accept either key name
            strategy = rec.get("final_topic") or rec.get("predicted_topic")
            if not strategy:
                print(f"[WARN] Missing strategy prediction for video {video_id} in {json_path.name}; skipping.")
                continue
            if not isinstance(strategy, str):
                print(
                    f"[WARN] Non-string strategy ({type(strategy).__name__}) for video {video_id} "
                    f"in {json_path.name}; skipping."
                )
                continue

            records_all.append((video_id, strategy))

    return records_all


# ------------------------------------------------------------------------- #
# Accuracy helpers
# ------------------------------------------------------------------------- #


def compute_accuracy_records(records: List[tuple[str, str]], gt: Dict[str, set[str]]) -> tuple[int, int]:
    """Compute (correct, total) for a list of prediction records (may contain duplicates).

    Records whose video id is absent from *gt* are left out of both counts, the
    same rule ``compute_accuracy_unique`` applies, so the two figures differ
    only by deduplication.

    Args:
        records: *(video_id, raw_prediction)* pairs; duplicates are each scored.
        gt: Mapping of video id to the set of accepted (normalised) strategies.

    Returns:
        ``(correct, total)`` where *total* counts only records that have a
        ground-truth entry.
    """
    correct = 0
    total = 0
    for vid, pred_raw in records:
        gt_set = gt.get(vid)
        if gt_set is None:
            continue  # unknown id: nothing to score against
        total += 1
        if normalize_label(pred_raw) in gt_set:
            correct += 1
    return correct, total


def compute_accuracy_unique(pred_unique: Dict[str, str], gt: Dict[str, set[str]]) -> tuple[int, int]:
    """Score one prediction per video and return ``(correct, total)``.

    Only videos that have an entry in *gt* are scored; a prediction counts as
    correct when its normalised label is in that video's accepted set.
    """
    # One boolean per scored video: True when the prediction was accepted.
    outcomes = [
        normalize_label(raw_prediction) in gt[video_id]
        for video_id, raw_prediction in pred_unique.items()
        if gt.get(video_id) is not None  # ids without ground truth are skipped
    ]
    return sum(outcomes), len(outcomes)


def _load_predictions_file(path: Path) -> List[tuple[str, str]]:
    """Return the *(video_id, strategy)* pairs found in one prediction JSON file.

    The file may contain a single record object or a list of them. Records that
    are not objects, or that carry no ``"final_topic"``/``"predicted_topic"``,
    are skipped; a missing ``"video_id"`` falls back to the file stem. A file
    that cannot be read or parsed yields an empty list after a warning.
    """
    try:
        with path.open("r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers both JSON syntax errors and invalid UTF-8.
        print(f"[WARN] Could not read {path.name}: {e}")
        return []

    recs = data if isinstance(data, list) else [data]
    pairs: List[tuple[str, str]] = []
    for rec in recs:
        if not isinstance(rec, dict):
            continue
        strat = rec.get("final_topic") or rec.get("predicted_topic")
        if not strat:
            continue
        pairs.append((rec.get("video_id") or path.stem, strat))
    return pairs


def main() -> None:
    """Command-line entry point: score every prediction file in a directory.

    For each ``*.json`` file in ``--pred_dir`` (in sorted order) two accuracies
    are computed against ``--annot_file``: one over all records ("dup") and one
    after keeping only the last record per video id ("unique"). One line per
    file is printed and written to ``--output``.

    Raises:
        NotADirectoryError: If ``--pred_dir`` is not an existing directory.
        FileNotFoundError: If ``--annot_file`` is not an existing file.
    """
    parser = argparse.ArgumentParser(description="Evaluate persuasion-strategy prediction accuracy.")
    parser.add_argument("--pred_dir", type=str, required=True, help="Directory containing prediction JSON files.")
    parser.add_argument("--annot_file", type=str, required=True, help="Path to emotion_annotation.json ground-truth file.")
    # %(default)s keeps the help text in sync with the actual default value.
    parser.add_argument("--output", type=str, default="np_metric.txt", help="File to write accuracy to (default: %(default)s)")
    args = parser.parse_args()

    pred_dir = Path(args.pred_dir)
    annot_path = Path(args.annot_file)
    out_path = Path(args.output)

    if not pred_dir.is_dir():
        raise NotADirectoryError(f"Prediction directory not found: {pred_dir}")
    if not annot_path.is_file():
        raise FileNotFoundError(f"Annotation file not found: {annot_path}")

    # Load ground truth once
    gt_map = load_ground_truth(annot_path)

    # Evaluate each JSON file separately
    with out_path.open("w", encoding="utf-8") as fout:
        for json_path in sorted(pred_dir.glob("*.json")):
            pred_records = _load_predictions_file(json_path)

            # duplicates accuracy
            c_dup, t_dup = compute_accuracy_records(pred_records, gt_map)
            acc_dup = (c_dup / t_dup) if t_dup else 0.0

            # unique accuracy: dict() keeps the last pair seen for each video id
            pred_unique: Dict[str, str] = dict(pred_records)

            c_u, t_u = compute_accuracy_unique(pred_unique, gt_map)
            acc_u = (c_u / t_u) if t_u else 0.0

            line = f"{json_path.name}: dup {acc_dup:.4f} | unique {acc_u:.4f}"
            print(line)
            fout.write(line + "\n")


if __name__ == "__main__":
    main()
